# This code is modified from https://github.com/dragen1860/MAML-Pytorch and https://github.com/katerakelly/pytorch-maml 

import backbone
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from methods.meta_template import MetaTemplate
from methods.min_norm_solvers import MinNormSolver, gradient_normalizers

class MAML_MOML(MetaTemplate):
    def __init__(self, model_func,  n_way, n_support, approx = False):
        """Build the backbone via the parent template and attach a fast-weight linear head.

        Args:
            model_func: Backbone factory, forwarded unchanged to ``MetaTemplate``.
            n_way: Number of classes per episode; also the width of the classifier output.
            n_support: Number of support examples per class.
            approx: When True, the inner loop uses the first-order approximation.
        """
        super().__init__(model_func, n_way, n_support, change_way=False)

        self.loss_fn = nn.CrossEntropyLoss()

        # Linear head from backbone.py (see Linear_fw there); its bias is zeroed at start-up.
        head = backbone.Linear_fw(self.feat_dim, n_way)
        head.bias.data.fill_(0)
        self.classifier = head

        # Meta-learning hyper-parameters.
        self.n_task = 3            # episodes accumulated before each optimizer step
        self.task_update_num = 5   # inner-loop gradient steps per episode
        self.train_lr = 0.01       # inner-loop step size
        self.approx = approx       # first order approx.

        # Read by train_loop, which handles 'ORG', 'SOML' and 'MGDA'; left unset here.
        self.weighting_mode = None

        #print([name for name,para in list(self.named_parameters())])

        
    def forward(self,x):
        out  = self.feature.forward(x)
        scores  = self.classifier.forward(out)
        return scores

    def set_forward(self,x, is_feature = False, robust = False ,LLmode = False):
        assert is_feature == False, 'MAML do not support fixed feature' 
        x = x.cuda()
        x_var = Variable(x)
        x_a_i = x_var[:,:self.n_support,:,:,:].contiguous().view( self.n_way* self.n_support, *x.size()[2:]) #support data 
        x_b_i = x_var[:,self.n_support:,:,:,:].contiguous().view( self.n_way* self.n_query,   *x.size()[2:]) #query data
        y_a_i = Variable( torch.from_numpy( np.repeat(range( self.n_way ), self.n_support ) )).cuda() #label for support data
        
        fast_parameters = list(self.parameters()) #the first gradient calcuated in line 45 is based on original weight
        for weight in self.parameters():
            weight.fast = None
        self.zero_grad()

        for task_step in range(self.task_update_num):
            scores = self.forward(x_a_i)
            # if self.LLmode == 'LLadv':
            #     if self.ATmode == 'A':
            #         LL_adv_input = self.test_FGSM(x_a_i, y_a_i, upper_limit, lower_limit, epsilon)
            #     elif self.ATmode == 'B':
            #         LL_adv_input = self.test_PGD(x_a_i, y_a_i, upper_limit, lower_limit, epsilon, step_num = 7)
            #     scores_robust = self.forward(LL_adv_input)
            #     set_loss = self.loss_fn( scores, y_a_i)  + 0.5 * self.loss_fn( scores_robust, y_a_i)
            # elif self.LLmode == 'LLsum':
            y_a_i = y_a_i.to(torch.int64)
            set_loss = self.loss_fn( scores, y_a_i)
            grad = torch.autograd.grad(set_loss, fast_parameters, create_graph=True) #build full graph support gradient of gradient
            if self.approx:
                grad = [ g.detach()  for g in grad ] #do not calculate gradient of gradient if using first order approximation
            fast_parameters = []
            for k, (name, weight) in enumerate(self.named_parameters()):
                #for usage of weight.fast, please see Linear_fw, Conv_fw in backbone.py 
                if weight.fast is None:
                    weight.fast = weight - self.train_lr * grad[k] #create weight.fast 
                else:
                    weight.fast = weight.fast - self.train_lr * grad[k] #create an updated weight.fast, note the '-' is not merely minus value, but to create a new weight.fast 
                fast_parameters.append(weight.fast) #gradients calculated in line 45 are based on newest fast weight, but the graph will retain the link to old weight.fasts
        
        scores_robust = 0
        if robust:
            y_b_i = Variable( torch.from_numpy( np.repeat(range( self.n_way ), self.n_query   ) ))
            #adv_input = self.test_FGSM(x_b_i, y_b_i, upper_limit, lower_limit, epsilon)
            adv_input = self.test_PGD(x_b_i, y_b_i, step_num = 7)
            scores_robust = self.forward(adv_input)
            scores = self.forward(x_b_i)
            return scores , scores_robust
        else:
            scores = self.forward(x_b_i)
            return scores

    def set_forward_adaptation(self,x, is_feature = False): #overwrite parrent function
        raise ValueError('MAML performs further adapation simply by increasing task_upate_num')


    def set_forward_loss(self, x, require_rob = False):
        y_b_i = Variable( torch.from_numpy( np.repeat(range( self.n_way ), self.n_query   ) )).cuda()
        if require_rob:
            scores, scores_robust = self.set_forward(x, is_feature = False, robust = True)
            y_b_i = y_b_i.to(torch.int64)
            loss = self.loss_fn(scores, y_b_i)
            loss_robust = self.loss_fn(scores_robust, y_b_i)
            return loss, loss_robust
        else:
            scores = self.set_forward(x, is_feature = False)
            loss = self.loss_fn(scores, y_b_i)
            return loss

    def train_loop(self, epoch, train_loader, optimizer): #overwrite parrent function
        print_freq = 10
        avg_loss_acc=0
        avg_loss_rob=0
        task_count = 0
        loss_acc = []
        loss_rob = []
        grads = {}
        scale = {}
        optimizer.zero_grad()
        tasks = ['acc','rob']
        x_list = []
        grads['acc'] = []
        grads['rob'] = []

        if self.weighting_mode == 'ORG':
            for i, (x,_) in enumerate(train_loader):
                self.n_query = x.size(1) - self.n_support
                assert self.n_way  ==  x.size(0), "MAML do not support way change"
                loss = self.set_forward_loss(x)
                avg_loss_acc = avg_loss_acc+loss.item()
                loss_acc.append(loss)
                task_count += 1
                if task_count == self.n_task: #MAML update several tasks at one time
                    loss_q = torch.stack(loss_acc).sum(0) / (self.n_task)
                    loss_q.backward()
                    optimizer.step()
                    task_count = 0
                    loss_acc = []
                optimizer.zero_grad()
                if i % print_freq==0:
                    print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), avg_loss_acc/float(i+1)))
        elif self.weighting_mode == 'SOML': 
            for i, (x,_) in enumerate(train_loader):
                self.n_query = x.size(1) - self.n_support
                assert self.n_way  ==  x.size(0), "MAML do not support way change"
                loss, loss_robust = self.set_forward_loss(x ,require_rob = True)
                avg_loss_acc = avg_loss_acc+loss.item()
                avg_loss_rob = avg_loss_rob+loss_robust.item()
                loss_acc.append(loss)
                loss_rob.append(loss_robust)
                task_count += 1
                alpha = 0.8
                if task_count == self.n_task: #MAML update several tasks at one time
                    loss_q = alpha * torch.stack(loss_acc).sum(0) + (1-alpha) * torch.stack(loss_rob).sum(0)
                    loss_q.backward()
                    optimizer.step()
                    task_count = 0
                    loss_acc = []
                    loss_rob = []
                optimizer.zero_grad()
                if i % print_freq==0:
                    print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f} Loss ADV {:f}'.format(epoch, i, len(train_loader), avg_loss_acc/float(i+1), avg_loss_rob/float(i+1)))
                    #print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), avg_loss_acc/float(i+1)))
        elif self.weighting_mode == 'MGDA':       
            for i, (x,_) in enumerate(train_loader):
                self.n_query = x.size(1) - self.n_support
                assert self.n_way  ==  x.size(0), "MAML do not support way change"
                loss, loss_robust = self.set_forward_loss(x, require_rob=True)
                avg_loss_acc = avg_loss_acc+loss.item()
                avg_loss_rob = avg_loss_rob+loss_robust.item()
                loss_acc.append(loss)
                loss_rob.append(loss_robust)
                task_count += 1
                if task_count == self.n_task:
                    loss_q_acc = torch.stack(loss_acc).sum(0)
                    loss_q_rob = torch.stack(loss_rob).sum(0)
                    grads['acc'] = torch.autograd.grad(loss_q_acc, self.parameters(), retain_graph=True)
                    grads['rob'] = torch.autograd.grad(loss_q_rob, self.parameters(), retain_graph=True)
                    loss_data = {'acc': loss_q_acc.item(), 'rob': loss_q_rob.item()}
                    gn = gradient_normalizers(grads, loss_data, normalization_type='loss+')
                    for t in tasks:
                        grads[t] = [gg / gn[t] for gg in grads[t]]

                    sol, _ = MinNormSolver.find_min_norm_element([grads[t] for t in tasks])
                    for j, t in enumerate(tasks):
                        scale[t] = float(sol[j])

                    loss_q = scale['acc'] * loss_q_acc + scale['rob'] * loss_q_rob
                    optimizer.zero_grad()
                    loss_q.backward()
                    optimizer.step()
                    optimizer.zero_grad()
                    task_count = 0
                    loss_acc = []
                    loss_rob = []
                    grads['acc'] = []
                    grads['rob'] = []
                    x_list = []
                if i % print_freq==0:
                    print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f} Loss ADV {:f}'.format(epoch, i, len(train_loader), avg_loss_acc/float(i+1), avg_loss_rob/float(i+1)))     
              
    def test_loop(self, test_loader, return_std = False): #overwrite parrent function
        correct =0
        count = 0
        acc_all = []
        acc_all2 = []
        iter_num = len(test_loader) 
        for i, (x,_) in enumerate(test_loader):
            self.n_query = x.size(1) - self.n_support
            assert self.n_way  ==  x.size(0), "MAML do not support way change"
            scores,scores2 = self.set_forward(x, robust = True)
            y_query = np.repeat(range( self.n_way ), self.n_query )
            topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
            topk_scores2, topk_labels2 = scores2.data.topk(1, 1, True, True)
            topk_ind = topk_labels.cpu().numpy()
            topk_ind2 = topk_labels2.cpu().numpy()
            top1_correct = np.sum(topk_ind[:,0] == y_query)
            top1_correct2 = np.sum(topk_ind2[:,0] == y_query)
            correct_this = float(top1_correct)
            correct_this2 = float(top1_correct2)
            count_this = len(y_query)
            acc_all.append(correct_this/ count_this *100 )
            acc_all2.append(correct_this2/ count_this *100 )
            
        acc_all  = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std  = np.std(acc_all)
        acc_all2  = np.asarray(acc_all2)
        acc_mean2 = np.mean(acc_all2)
        acc_std2  = np.std(acc_all2)
        B_score = 2 * (acc_all * acc_all2) / (acc_all + acc_all2)
        B_score2 = 2 * (np.mean(acc_all) * np.mean(acc_all2)) / (np.mean(acc_all) + np.mean(acc_all2))
        print('%d Test Acc = %4.2f%% +- %4.2f%%' %(iter_num,  acc_mean, 1.96* acc_std/np.sqrt(iter_num)))
        print('%d Test Rob = %4.2f%% +- %4.2f%%' %(iter_num,  acc_mean2, 1.96* acc_std2/np.sqrt(iter_num)))
        print('%d Test B Acc = %4.2f%% and %4.2f%% +- %4.2f%%' %(iter_num, B_score2,  np.mean(B_score), 1.96* np.std(B_score)/np.sqrt(iter_num)))
        return acc_mean, acc_mean2, np.mean(B_score)

    def clamp(self, X, lower_limit, upper_limit):
        return torch.max(torch.min(X, upper_limit), lower_limit)

    def test_PGD(self, x, y, step_num = 2):
        eps = 2/255 * torch.FloatTensor([1.0,1.0,1.0]).cuda()
        mean=  torch.FloatTensor([0.485, 0.456, 0.406]).cuda()
        std =  torch.FloatTensor([0.229, 0.224, 0.225]).cuda()
        epsilon = ((eps ) / std).reshape(3,1,1)
        upper_limit = torch.FloatTensor([2.2489, 2.4286, 2.6400]).reshape(3,1,1).cuda()
        lower_limit = torch.FloatTensor([-2.1179, -2.0357, -1.8044]).reshape(3,1,1).cuda()
        labels = Variable(y, requires_grad=False).cuda()
        images = Variable(x, requires_grad=True).cuda()
        
        step_size = 1.5 / step_num * epsilon

        for i in range(step_num):
            scores_test = self.forward(images)
            labels = labels.to(torch.int64)
            loss = self.loss_fn( scores_test, labels) 
            #loss.backward(retain_graph=True)
            grad = torch.autograd.grad(loss, images, 
                                    retain_graph=False, create_graph=False)[0]
            grad = grad.detach().data
            adv_images = images.detach().data + step_size * torch.sign(grad)
            delta = self.clamp(adv_images - x, -epsilon, epsilon)
            adv_images = self.clamp(x + delta, lower_limit, upper_limit)
            images = Variable(adv_images, requires_grad=True).cuda()
        return images


# def test_FGSM(self, x, y,upper_limit, lower_limit, epsilon):
#     labels = Variable(y, requires_grad=False).cuda()
#     images = Variable(x, requires_grad=True).cuda()

#     scores_test = self.forward(images)
#     loss = self.loss_fn( scores_test, labels) 
#     # loss.backward(retain_graph=True)
#     grad = torch.autograd.grad(loss, images, 
#                                 retain_graph=False, create_graph=False)[0]
                                
#     grad = grad.detach().data
#     delta = self.clamp(epsilon * torch.sign(grad), -epsilon, epsilon)
#     delta = self.clamp(delta, lower_limit.cuda() - images.data, upper_limit.cuda() - images.data)
#     adv_input = Variable(images.data + delta, requires_grad=False).cuda()
#     return adv_input
      
#        elif self.UL_model_train == 4 :
#
#            loss_acc = []
#            grads = {}
#            scale = {}
#            optimizer.zero_grad()
#            tasks = ['1','2','3','4']
#            x_list = []
#            grads['1'] = []
#            grads['2'] = []
#            grads['3'] = []
#            grads['4'] = []
#    
#            #train
#            for i, (x,_) in enumerate(train_loader):
#                self.n_query = x.size(1) - self.n_support
#                assert self.n_way  ==  x.size(0), "MAML do not support way change"
#                #forward 1 and save x
#                x_list.append(x)
#    
#                if task_count == 0:
#                    loss_acc.append(self.set_forward_loss(x))
#                    loss_q_1 = torch.stack(loss_acc).sum(0)
#                    loss_q_1.backward()
#                    for param in self.feature.parameters():
#                        if param.grad is not None:
#                            grads['1'].append(Variable(param.grad.data.clone(), requires_grad=False))
#    
#                    optimizer.zero_grad()
#                    loss_acc = []
#                    task_count += 1
#                
#                elif task_count == 1:
#                    loss_acc.append(self.set_forward_loss(x))
#                    loss_q_2 = torch.stack(loss_acc).sum(0)
#                    loss_q_2.backward()
#                    for param in self.feature.parameters():
#                        if param.grad is not None:
#                            grads['2'].append(Variable(param.grad.data.clone(), requires_grad=False))
#    
#                    optimizer.zero_grad()
#                    loss_acc = []
#                    task_count += 1
#    
#                elif task_count == 2:
#                    loss_acc.append(self.set_forward_loss(x))
#                    loss_q_3 = torch.stack(loss_acc).sum(0)
#                    #print(loss_q_3)
#                    
#                    loss_q_3.backward()
#                    for param in self.feature.parameters():
#                        if param.grad is not None:
#                            grads['3'].append(Variable(param.grad.data.clone(), requires_grad=False))
#    
#                    optimizer.zero_grad()
#                    loss_acc = []
#                    task_count += 1
#    
#                elif task_count == 3:
#                    loss_acc.append(self.set_forward_loss(x))
#                    loss_q_4 = torch.stack(loss_acc).sum(0)
#                    loss_q_4.backward()
#                    for param in self.feature.parameters():
#                        if param.grad is not None:
#                            grads['4'].append(Variable(param.grad.data.clone(), requires_grad=False))
#    
#                    optimizer.zero_grad()
#                    loss_acc = []
#                    task_count = 0
#    
#                    sol, _ = MinNormSolver.find_min_norm_element([grads[t] for t in tasks])
#                    for i, t in enumerate(tasks):
#                        scale[t] = float(sol[i])
#                    
#                    print(scale['1'],scale['2'],scale['3'],scale['4'])
#    
#                    for i in range(self.n_task):
#                        loss = self.set_forward_loss(x_list[i])
#                        loss_acc.append(loss)
#                        
#                    #print(loss_acc[0].sum(0))
#                    
#                    loss_q = loss_acc[0].sum(0) * scale['1'] + loss_acc[1].sum(0) * scale['2'] + loss_acc[2].sum(0) * scale['3']+ loss_acc[3].sum(0) * scale['4']
#                    
#                    loss_q.backward()
#                    optimizer.step()
#                    optimizer.zero_grad()
#                    grads['1'] = []
#                    grads['2'] = []
#                    grads['3'] = []
#                    grads['4'] = []
#                    x_list = []
#                    loss_acc = []
#    
#                if i % print_freq==0:
#                    print('Epoch {:d} | Batch {:d}/{:d} '.format(epoch, i, len(train_loader), ))
#                    #gn = gradient_normalizers(grads, loss_data, params['normalization_type'])
#                    #for t in tasks:
#                    #    for gr_i in range(len(grads[t])):
#                    #        grads[t][gr_i] = grads[t][gr_i] / gn[t]
#    
#                    # Frank-Wolfe iteration to compute scales.c


